# -*- coding: utf-8 -*-
import numpy as np
import collections
from envs.servernode_w_appqueue_w_appinfo_cores import ServerNode as Edge
from envs.servernode_w_totalqueue_cores import ServerNode as Cloud
from envs.applications import *
from envs.channels import *
from envs.constants import *
from envs.utils import *
import gym
from gym import spaces
from gym.utils import seeding

class MEC(gym.Env):
    def __init__(self, task_rate=10, applications=(SPEECH_RECOGNITION, NLP, VR), time_delta=1, use_beta=True, empty_reward=True, cost_type=1, max_episode_steps=1000, time_stamp=0):
        """Store the configuration and derive the action/observation spaces.

        A temporary edge/cloud pair is built only to measure the state and
        action dimensions.  The node and link registries are emptied again
        before returning, so ``reset()`` has to be called before stepping.

        Args:
            task_rate: Rate handed to each client's random task generation.
            applications: Application types given to the edge node's queues.
            time_delta: Accepted for interface compatibility; not used here.
            use_beta: When true, the action also carries an offloading half
                and the observation also includes the cloud state.
            empty_reward: Stored on the instance; not read in this class.
            cost_type: Multiplier applied to the power term in ``get_cost``.
            max_episode_steps: Timestamp at which ``step`` reports ``done``.
            time_stamp: Initial value of the environment timestamp.
        """
        super().__init__()

        # Run configuration.
        self.task_rate = task_rate
        self.applications = applications
        self.use_beta = use_beta
        self.empty_reward = empty_reward
        self.cost_type = cost_type
        self.max_episode_steps = max_episode_steps
        self.timestamp = time_stamp
        self.silence = True

        # Registries and dimensions must exist before the probe pair is built.
        self.state_dim = 0
        self.action_dim = 0
        self.clients = dict()
        self.servers = dict()
        self.links = list()

        # One (edge capability, cloud capability, channel) spec per linked
        # pair; reset() replays these to rebuild the topology.
        edge_capability = NUM_EDGE_CORES * NUM_EDGE_SINGLE * GHZ
        cloud_capability = NUM_CLOUD_CORES * NUM_CLOUD_SINGLE * GHZ
        channel = HARD
        self.reset_info = [(edge_capability, cloud_capability, channel)]

        probe_state = self.init_linked_pair(*self.reset_info[0])
        # The last three entries of the raw state are not exposed to agents.
        self.obs_dim = probe_state.size - 3

        action_bound = np.ones(self.action_dim)
        self.action_space = spaces.Box(-action_bound, action_bound)
        # Cleared so the init_linked_pair() calls in reset() start from zero.
        self.action_dim = 0

        obs_bound = np.full(self.obs_dim, np.inf)
        self.observation_space = spaces.Box(-obs_bound, obs_bound)

        # Drop the probe topology.
        self.clients = dict()
        self.servers = dict()
        self.links = list()

    def seed(self, seed=None):
        """Create the environment's random generator and return the seed used.

        Args:
            seed: Value forwarded to ``gym.utils.seeding.np_random``.

        Returns:
            A one-element list holding the seed reported by ``np_random``.
        """
        generator, used_seed = seeding.np_random(seed)
        self.np_random = generator
        return [used_seed]

    def init_linked_pair(self, edge_capability, cloud_capability, channel):
        """Create one edge client, one cloud server and the link between them.

        Also refreshes ``state_dim`` and grows ``action_dim`` by this pair's
        contribution: one entry per client application plus one extra,
        doubled when ``use_beta`` is set.

        Args:
            edge_capability: Capability passed to ``add_client``.
            cloud_capability: Capability passed to ``add_server``.
            channel: Channel type passed to ``add_link`` for the uplink.

        Returns:
            The observation array produced by ``_get_obs()`` once the pair
            has been registered.
        """
        client = self.add_client(edge_capability)
        client.make_application_queues(*self.applications)

        server = self.add_server(cloud_capability)

        self.add_link(client, server, channel)

        state = self._get_obs()
        self.state_dim = len(state)

        # Double only this pair's share.  Doubling the running total would
        # compound the dimensions of pairs registered earlier.
        pair_action_dim = len(client.get_applications()) + 1
        if self.use_beta:
            pair_action_dim *= 2
        self.action_dim += pair_action_dim
        return state

    def add_client(self, cap):
        """Create an edge node, register it under its uuid and return it.

        Args:
            cap: Capability forwarded to the ``Edge`` constructor.  The
                second constructor argument is always ``True``; its meaning
                is defined by the ``Edge`` class.
        """
        edge_node = Edge(cap, True)
        node_id = edge_node.get_uuid()
        self.clients[node_id] = edge_node
        return edge_node

    def add_server(self, cap):
        """Create a cloud node, register it under its uuid and return it.

        Args:
            cap: Capability forwarded to the ``Cloud`` constructor.
        """
        cloud_node = Cloud(cap)
        node_id = cloud_node.get_uuid()
        self.servers[node_id] = cloud_node
        return cloud_node

    def add_link(self, client, server, up_channel, down_channel=None):
        """Connect ``client`` and ``server`` with an uplink and a downlink.

        Args:
            client: Edge node; gains an entry in ``links_to_higher``.
            server: Cloud node; gains an entry in ``links_to_lower``.
            up_channel: Channel type used to build the uplink ``Channel``.
            down_channel: Channel type for the downlink.  When ``None`` the
                uplink's channel type is reused.
        """
        # Build the default downlink from the raw channel *type*.  The
        # previous code passed the already-constructed uplink Channel
        # instance into Channel() instead.  ``is None`` is used so that a
        # falsy channel-type constant is not mistaken for "not given".
        if down_channel is None:
            down_channel = up_channel

        uplink = Channel(up_channel)
        downlink = Channel(down_channel)

        client_id = client.get_uuid()
        server_id = server.get_uuid()

        client.links_to_higher[server_id] = {
            'node': server,
            'channel': uplink
        }
        server.links_to_lower[client_id] = {
            'node': client,
            'channel': downlink
        }
        self.links.append((client_id, server_id))
        return

    def get_number_of_apps(self):
        """Return how many application types the environment is configured with."""
        app_count = len(self.applications)
        return app_count

    def __del__(self):
        """Empty the node registries and drop the link/application attributes.

        ``reset()`` calls this explicitly and the interpreter may call it
        again at finalisation, possibly on an instance whose ``__init__``
        did not complete.  Every step therefore tolerates attributes that
        are already missing instead of raising ``AttributeError``.
        """
        for registry_name in ("clients", "servers"):
            registry = getattr(self, registry_name, None)
            if registry is not None:
                registry.clear()

        for attr_name in ("links", "applications"):
            try:
                delattr(self, attr_name)
            except AttributeError:
                # Already removed or never set: nothing left to clean up.
                pass

    def reset(self, empty_reward=True, rand_start = 0):
        """Rebuild the environment from scratch and return the first observation.

        The instance is torn down with ``__del__`` and re-initialised with the
        configuration it already had.  The timestamp restarts at 0, every pair
        recorded in ``reset_info`` is recreated and one round of task
        generation is run.

        Args:
            empty_reward: Ignored; the value stored on the instance is reused.
            rand_start: Ignored.

        Returns:
            The observation with the last three raw entries removed and its
            final remaining entry set to 0.0.
        """
        # Snapshot the configuration first: __del__ removes some attributes.
        init_kwargs = dict(
            task_rate=self.task_rate,
            applications=self.applications,
            use_beta=self.use_beta,
            empty_reward=self.empty_reward,
            cost_type=self.cost_type,
            max_episode_steps=self.max_episode_steps,
            time_stamp=0,
        )
        self.__del__()
        self.__init__(**init_kwargs)

        for pair_spec in self.reset_info:
            self.init_linked_pair(*pair_spec)

        # Seed the queues with an initial batch of tasks; results are unused.
        self._step_generation()

        observation = self._get_obs(scale=GHZ)[:-3]
        observation[-1] = 0.0
        return observation

    def _get_obs(self, scale=GHZ):
        """Collect the raw observation vector from the registered nodes.

        Edge-node observations always come first.  When ``use_beta`` is set
        the cloud-node observations are appended after them.

        Args:
            scale: Forwarded to each node's ``_get_obs``.

        Returns:
            A numpy array built from the concatenated node observations.
        """
        observation = []
        for edge_node in self.clients.values():
            observation.extend(edge_node._get_obs(self.timestamp, scale=scale))

        if self.use_beta:
            cloud_part = []
            for cloud_node in self.servers.values():
                cloud_part.extend(cloud_node._get_obs(self.timestamp, scale=scale))

            # Link rates are sampled in both directions but are not part of
            # the returned vector.  The calls are kept because sampling may
            # have side effects inside the nodes - TODO confirm.
            link_rates = []
            for client_id, server_id in self.links:
                up_rate = self.clients[client_id].sample_channel_rate(server_id)
                down_rate = self.servers[server_id].sample_channel_rate(client_id)
                link_rates.extend([up_rate, down_rate])

            observation = observation + cloud_part
        return np.array(observation)

    def step(self, action, use_beta=True, generate=True):
        """Advance the environment by one time step.

        The action is clipped to [-1, 1].  With ``self.use_beta`` set it is
        split into a first half (edge processing, alpha) and a second half
        (offloading, beta), each passed through ``softmax_1d``.  Edge
        processing runs first, then offloading and cloud processing, then
        new tasks are generated.

        Args:
            action: Raw action array.
            use_beta: Ignored; ``self.use_beta`` decides the action layout.
            generate: Ignored.

        Returns:
            ``(observation, reward, done, info)``.  ``reward`` is the negated
            ``queue_costs`` array from ``get_cost``.  ``done`` is 1 once the
            timestamp equals ``max_episode_steps``, else 0.  ``info`` holds
            the last entry of the pre-step raw observation and the negated
            power cost.
        """
        action = np.clip(action, -1.0, 1.0)
        start_state = self._get_obs(scale=GHZ)
        q1 = np.array(self.get_edge_qlength(scale=GHZ))

        # NOTE(review): without use_beta the beta part stays an empty list,
        # which _step_beta cannot flatten - confirm whether that mode is used.
        beta_part = list()
        if self.use_beta:
            flat_action = action.flatten()
            half = int(self.action_dim / 2)
            alpha_part = flat_action[:half].reshape(1, -1)
            beta_part = softmax_1d(flat_action[half:].reshape(1, -1))
        else:
            alpha_part = action
        alpha_part = softmax_1d(alpha_part)

        used_edge_cpus, _ = self._step_alpha(alpha_part)
        used_cloud_cpus, _, q_last = self._step_beta(beta_part)

        self._step_generation()

        observation = self._get_obs(scale=GHZ)[:-3]
        # The final slot carries the first server's CPU usage, divided by GHZ
        # and by the constant 216.0.
        observation[-1] = list(used_cloud_cpus.values())[0] / GHZ / 216.0

        cost_vector = self.get_cost(used_edge_cpus, used_cloud_cpus, q1, q_last, cost_type=self.cost_type)

        self.timestamp += 1
        done = int(self.timestamp == self.max_episode_steps)

        info = {
            "cloud_cpu_used": start_state[-1],
            "v_power": -cost_vector["v_power"]
        }
        reward = -cost_vector["queue_costs"]
        return observation, reward, done, info

    def _step_alpha(self, action):
        """Run edge-side processing with the given allocation.

        The last entry of the flattened action is dropped and the rest is
        reshaped to a single row, so only the first registered client is
        paired with an allocation.

        Args:
            action: Array of allocation weights.

        Returns:
            ``(used_edge_cpus, state)``: a defaultdict mapping client id to
            the value returned by ``do_tasks``, and the raw observation taken
            afterwards.
        """
        allocation_rows = action.flatten()[:-1].reshape(1, -1)

        used_edge_cpus = collections.defaultdict(float)
        for client_id, alpha in zip(self.clients.keys(), allocation_rows):
            edge_node = self.clients[client_id]
            used_edge_cpus[client_id] = edge_node.do_tasks(alpha)

        return used_edge_cpus, self._get_obs(scale=GHZ)
    
    def _step_beta(self, action):
        """Offload tasks from the clients and run cloud-side processing.

        The last entry of the flattened action is dropped and the rest is
        reshaped to a single row, so only the first registered client is
        paired with an offloading vector.

        Args:
            action: Array of offloading weights.

        Returns:
            ``(used_cloud_cpus, state, q_last)``: a defaultdict mapping server
            id to the value returned by ``do_tasks``, the raw observation
            taken at the end, and the edge queue lengths measured right after
            offloading (before the servers run).
        """
        used_txs = collections.defaultdict(list)
        pending_offloads = collections.defaultdict(dict)
        used_cloud_cpus = collections.defaultdict(float)
        offload_rows = action.flatten()[:-1].reshape(1, -1)

        for edge_node, beta in zip(self.clients.values(), offload_rows):
            for higher_node in edge_node.get_higher_node_ids():
                used_tx, offloaded, _ = edge_node.offload_tasks(beta, higher_node)
                used_txs[higher_node].append(used_tx)
                pending_offloads[higher_node].update(offloaded)

        q_last = self.get_edge_qlength(scale=GHZ)

        for server_id, cloud_node in self.servers.items():
            cloud_node.offloaded_tasks(pending_offloads[server_id], self.timestamp)
            # Result unused; the call is kept because observing may have
            # side effects (see the link sampling in _get_obs) - TODO confirm.
            self._get_obs(scale=GHZ)
            used_cloud_cpus[server_id] = cloud_node.do_tasks()

        return used_cloud_cpus, self._get_obs(scale=GHZ), q_last
    
    def _step_generation(self):
        """Let every client generate new random tasks.

        Returns:
            ``(initial_qlength, failed_to_generate, after_qlength)``.
            ``initial_qlength`` is read with the default scale, while
            ``after_qlength`` is read with ``scale=GHZ``.
            ``failed_to_generate`` is the value reported by the last client,
            or 0 when no client is registered.
        """
        initial_qlength = self.get_edge_qlength()
        # Defined up front so an empty client registry cannot leave the name
        # unbound at the return statement.
        failed_to_generate = 0

        if not self.silence: print("###### random task generation start! ######")
        for client in self.clients.values():
            _, failed_to_generate = client.random_task_generation(self.task_rate, self.timestamp, *self.applications)
        if not self.silence: print("###### random task generation ends! ######")

        after_qlength = self.get_edge_qlength(scale=GHZ)

        return initial_qlength, failed_to_generate, after_qlength

    def get_edge_qlength(self, scale=1):
        """Return the length of every application queue on every edge node.

        Args:
            scale: Forwarded to each queue's ``get_length``.

        Returns:
            A flat list of queue lengths, ordered by client and then by the
            order of ``get_queue_list()``.
        """
        return [
            queue.get_length(scale)
            for edge_node in self.clients.values()
            for _, queue in edge_node.get_queue_list()
        ]

    def get_cloud_qlength(self, scale=1):
        """Return the task-queue length of every cloud node as a numpy array.

        Args:
            scale: Forwarded to each node's ``get_task_queue_length``.
        """
        lengths = [
            cloud_node.get_task_queue_length(scale)
            for cloud_node in self.servers.values()
        ]
        return np.array(lengths)


    def get_cost(self, used_edge_cpus, used_cloud_cpus, before, after, cost_type, failed_to_offload=0, failed_to_generate=0):
        """Compute the queue and power costs of one step.

        Each CPU usage value contributes ``cores * (usage/400/GHZ/cores)**3``,
        with 10 cores for edge nodes and 54 for cloud nodes.  Each total is
        then multiplied by 1000.

        Args:
            used_edge_cpus: Mapping whose values are edge CPU usages.
            used_cloud_cpus: Mapping whose values are cloud CPU usages.
            before: Unused.
            after: Queue lengths; returned as ``queue_costs``.
            cost_type: Multiplier applied to the summed power.
            failed_to_offload: Unused.
            failed_to_generate: Unused.

        Returns:
            Dict with keys ``queue_costs``, ``edge_power``, ``cloud_power``
            and ``v_power``.
        """
        def total_cubic_cost(cores, usages):
            # Accumulate in iteration order with plain additions.
            total = 0
            for cpu_usage in usages:
                total += cores*(cpu_usage/400/GHZ/cores)**3
            return total

        edge_cores, cloud_cores = 10, 54

        edge_power = 1000*total_cubic_cost(edge_cores, used_edge_cpus.values())
        cloud_power = 1000*total_cubic_cost(cloud_cores, used_cloud_cpus.values())

        costs = {
            "queue_costs": np.array(after),
            "edge_power": edge_power,
            "cloud_power": cloud_power,
            "v_power": cost_type*(edge_power+cloud_power)
        }
        return costs